{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "da0308d8",
   "metadata": {},
   "source": [
    "# Seq2Seq + Attention\n",
    "\n",
    "- [论文pdf](https://arxiv.org/abs/1409.0473)\n",
    "\n",
    "模型训练特别慢（大约慢 20 倍）：Bahdanau Attention 的解码器每一步都依赖上一步的隐状态来计算注意力，只能逐步串行计算，这是它的设计缺陷；Transformer 正是针对这一问题的解决方案。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc20c872",
   "metadata": {},
   "source": [
    "## autodl 环境准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "543cc8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install datasets transformers scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b4f2ffcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# os.environ['no_proxy'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'\n",
    "# os.environ['NO_PROXY'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'\n",
    "\n",
    "# os.environ['http_proxy'] = 'http://100.72.64.19:12798'\n",
    "# os.environ['HTTP_PROXY'] = 'http://100.72.64.19:12798'\n",
    "\n",
    "# os.environ['https_proxy'] = 'http://100.72.64.19:12798'\n",
    "# os.environ['HTTPS_PROXY'] = 'http://100.72.64.19:12798'\n",
    "# Point the HF_ENDPOINT environment variable at the hf-mirror.com mirror.\n",
    "# NOTE(review): presumably this must run before `datasets`/`transformers` are imported to take effect - confirm.\n",
    "os.environ['HF_ENDPOINT']=\"https://hf-mirror.com\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43492e99",
   "metadata": {},
   "source": [
    "## 初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "50f4ec5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "beadb6db",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 6\n",
    "MAX_DATA = 50000\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "learning_rate=5e-4\n",
    "weight_decay=1e-5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e5a4e3f",
   "metadata": {},
   "source": [
    "## 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d201a8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from datasets.download import DownloadConfig\n",
    "\n",
    "\n",
    "# Load the \"train\" split of the \"zh-en\" configuration of the wmt/wmt17 dataset.\n",
    "full_dataset = load_dataset(\"wmt/wmt17\", \"zh-en\", split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2570020f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "从 总共 1141860 个数据中选择了 50000 个\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Salta Province, July 1995;\n",
      "萨尔塔省，1995年7月；\n",
      "Sadly, they are unlikely to end soon. Talk of a “grand bargain” remains just that – talk.\n",
      "不幸的是，这些谈判能在短期内结束的希望颇为渺茫，关于“大妥协”事宜的谈判仍将仅仅停留在口头之上。\n",
      "But ideas matter.\n",
      "但是这个观念很重要。\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "def is_valid_trans(x) -> bool:\n",
    "    en = x[\"translation\"][\"en\"]\n",
    "    zh = x[\"translation\"][\"zh\"]\n",
    "    if not en or not zh:\n",
    "        return False\n",
    "    if len(en) < 3 or len(zh) < 3:\n",
    "        return False\n",
    "    # 核心过滤条件\n",
    "    if len(en) > 100 or len(zh) > 100:\n",
    "        return False\n",
    "    ratio = len(en) / len(zh)\n",
    "    if ratio < 0.5 or ratio > 2:\n",
    "        return False\n",
    "    if not re.search(r'[\\u4e00-\\u9fff]', zh):\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "# Drop every pair rejected by is_valid_trans, using 10 worker processes.\n",
    "full_dataset = full_dataset.filter(is_valid_trans, num_proc=10)\n",
    "\n",
    "# Keep at most MAX_DATA leading pairs, then shuffle them with a fixed seed.\n",
    "dataset = full_dataset.select(\n",
    "    range(min(MAX_DATA, len(full_dataset)))\n",
    ").shuffle(seed=20250709)\n",
    "\n",
    "print(f\"从 总共 {len(full_dataset)} 个数据中选择了 {len(dataset)} 个\")\n",
    "\n",
    "# Preview the first three pairs of the shuffled subset.\n",
    "sample = dataset.select(range(3))\n",
    "print(\"-\"*100)\n",
    "for pair in sample:\n",
    "    texts = pair[\"translation\"]\n",
    "    print(texts[\"en\"])\n",
    "    print(texts[\"zh\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "49223f55",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class WmtDataset(Dataset):\n",
    "    def __init__(self, hf_dataset, tokenizer_en, tokenizer_zh, max_length=100):\n",
    "        self.dataset = hf_dataset\n",
    "        self.tokenizer_en = tokenizer_en\n",
    "        self.tokenizer_zh = tokenizer_zh\n",
    "        self.max_length = max_length\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        item = self.dataset[idx]\n",
    "        en_text = item[\"translation\"][\"en\"]\n",
    "        zh_text = item[\"translation\"][\"zh\"]\n",
    "\n",
    "        en_tokens = self.tokenizer_en(en_text, \n",
    "                                        max_length=self.max_length, \n",
    "                                        padding='max_length', \n",
    "                                        truncation=True, \n",
    "                                        return_tensors='pt')\n",
    "            \n",
    "        zh_tokens = self.tokenizer_zh(zh_text, \n",
    "                                        max_length=self.max_length, \n",
    "                                        padding='max_length', \n",
    "                                        truncation=True, \n",
    "                                        return_tensors='pt')\n",
    "            \n",
    "        return en_tokens['input_ids'].squeeze(), zh_tokens['input_ids'].squeeze(),\n",
    "        \n",
    "\n",
    "# bert-base-multilingual-cased has a vocabulary of roughly 110k tokens, which is too large here,\n",
    "# so a separate monolingual tokenizer is used for each side.\n",
    "encoder_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "decoder_tokenizer = AutoTokenizer.from_pretrained(\"bert-base-chinese\")\n",
    "encoder_vocab_size = encoder_tokenizer.vocab_size\n",
    "decoder_vocab_size = decoder_tokenizer.vocab_size\n",
    "\n",
    "# First 95% of the subset is used for training, the remaining 5% for validation.\n",
    "total = len(dataset)\n",
    "split_idx = int(total * 0.95)\n",
    "train_hf_dataset = dataset.select(range(split_idx))\n",
    "test_hf_dataset = dataset.select(range(split_idx, total))\n",
    "train_data = WmtDataset(\n",
    "    train_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer\n",
    ")\n",
    "val_data = WmtDataset(\n",
    "    test_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer\n",
    ")\n",
    "\n",
    "# Only the training loader shuffles; memory pinning is enabled exactly when CUDA is available.\n",
    "train_iter = DataLoader(\n",
    "    train_data,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=12,\n",
    "    pin_memory=torch.cuda.is_available(),\n",
    ")\n",
    "val_iter = DataLoader(\n",
    "    val_data,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    num_workers=12,\n",
    "    pin_memory=torch.cuda.is_available(),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34113774",
   "metadata": {},
   "source": [
    "## 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7c0d5c57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Encoder, self).__init__()\n",
    "        \n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "        \n",
    "        # Embedding layer to convert token IDs to dense vectors\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)\n",
    "        \n",
    "        # GRU layer for processing sequences\n",
    "        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, \n",
    "                        batch_first=True, dropout=dropout, bidirectional=False)\n",
    "        \n",
    "    def forward(self, input_seq):\n",
    "        # Convert token IDs to embeddings\n",
    "        embedded = self.embedding(input_seq)  # [batch_size, seq_len, embed_size]\n",
    "        \n",
    "        # Pass through GRU\n",
    "        outputs, hidden = self.rnn(embedded)\n",
    "        \n",
    "        # outputs: [batch_size, seq_len, hidden_size]\n",
    "        # hidden: [num_layers, batch_size, hidden_size] \n",
    "        \n",
    "        return outputs, hidden\n",
    "\n",
    "class Attention(nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super(Attention, self).__init__()\n",
    "        # e_ij = alpha(s_{t-1}, h_j)\n",
    "        # 这里是： W_a*h_j + U_a*s_{t-1}\n",
    "        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)\n",
    "        self.U_a = nn.Linear(hidden_size, hidden_size, bias=False)\n",
    "        self.v_a = nn.Linear(hidden_size, 1)\n",
    "    def forward(self, hidden:torch.Tensor, encoder_out:torch.Tensor):\n",
    "        # hidden: [num_layer, batch_size, hidden_size]\n",
    "        # encoder_out: [batch_size, seq_len, hidden_size]\n",
    "        hidden = hidden[-1].unsqueeze(1) # [batch_size, 1, hidden_size]\n",
    "        energy1 = self.W_a(encoder_out) # [batch_size, 1, hidden_size]\n",
    "        energy2 = self.U_a(hidden) # [batch_size, seq_len, hidden_size]\n",
    "        energy = F.tanh(energy1 + energy2)\n",
    "        energy = self.v_a(energy) # [batch_size, seq_len, 1]\n",
    "        # [batch_size, seq_len, 1] --> # [batch_size, 1, seq_len]\n",
    "        attn_weights = F.softmax(energy, dim=1).transpose(1, 2)\n",
    "        return attn_weights\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"Single-step GRU decoder that attends over the encoder outputs.\n",
    "\n",
    "    At each call the embedded input token is concatenated with an attention\n",
    "    context vector (weighted sum of encoder outputs) and fed to the GRU; the\n",
    "    GRU output is projected to vocabulary logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Decoder, self).__init__()\n",
    "\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)\n",
    "        # The GRU input carries both the token embedding and the attention context.\n",
    "        self.rnn = nn.GRU(embed_size + hidden_size, hidden_size, num_layers, \n",
    "                        batch_first=True, dropout=dropout, bidirectional=False)\n",
    "        self.output_projection = nn.Linear(hidden_size, vocab_size)\n",
    "        self.attention = Attention(hidden_size)\n",
    "\n",
    "    def forward(self, input_token, hidden, encoder_outs):\n",
    "        \"\"\"Decode one step.\n",
    "\n",
    "        input_token:  [batch_size, 1] ids of the current input token\n",
    "        hidden:       [num_layer, batch_size, hidden_size] previous decoder state\n",
    "        encoder_outs: [batch_size, src_len, hidden_size]\n",
    "\n",
    "        Returns ``(outputs, hidden)`` with outputs: [batch_size, 1, vocab_size]\n",
    "        and the updated hidden state.\n",
    "        \"\"\"\n",
    "        embedded = self.embedding(input_token)  # [batch_size, 1, embed_size]\n",
    "        # Call the module itself (not .forward) so registered hooks are honoured.\n",
    "        attn_weights = self.attention(hidden, encoder_outs)  # [batch_size, 1, src_len]\n",
    "        context = torch.bmm(attn_weights, encoder_outs)      # [batch_size, 1, hidden_size]\n",
    "        # rnn_input: [batch_size, 1, embed_size + hidden_size]\n",
    "        rnn_input = torch.cat([embedded, context], dim=2)\n",
    "        rnn_out, hidden = self.rnn(rnn_input, hidden)\n",
    "        outputs = self.output_projection(rnn_out)\n",
    "        return outputs, hidden\n",
    "\n",
    "class Seq2Seq(nn.Module):\n",
    "    \"\"\"GRU encoder-decoder in which the decoder attends over all encoder outputs.\"\"\"\n",
    "\n",
    "    def __init__(self, encoder_vocab_size, decoder_vocab_size, embed_size=512, \n",
    "                hidden_size=1024, num_layers=2, dropout=0.1):\n",
    "        super(Seq2Seq, self).__init__()\n",
    "\n",
    "        self.encoder_vocab_size = encoder_vocab_size\n",
    "        self.decoder_vocab_size = decoder_vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        # Encoder and decoder share embed/hidden sizes so the encoder's final\n",
    "        # hidden state can seed the decoder directly.\n",
    "        self.encoder = Encoder(encoder_vocab_size, embed_size, hidden_size, num_layers, dropout)\n",
    "        self.decoder = Decoder(decoder_vocab_size, embed_size, hidden_size, num_layers, dropout)\n",
    "\n",
    "    def forward(self, source_seq:torch.Tensor, target_seq:torch.Tensor):\n",
    "        \"\"\"Teacher-forced decoding.\n",
    "\n",
    "        source_seq: [batch_size, src_len] source token ids\n",
    "        target_seq: [batch_size, tgt_len] decoder input ids; step i is fed target_seq[:, i]\n",
    "\n",
    "        Returns logits of shape [batch_size, tgt_len, decoder_vocab_size].\n",
    "        \"\"\"\n",
    "        batch_size, tgt_len = target_seq.shape\n",
    "        all_outs = torch.zeros(batch_size, tgt_len, self.decoder_vocab_size, device=source_seq.device)\n",
    "        encoder_outs, hidden = self.encoder(source_seq)\n",
    "\n",
    "        for i in range(tgt_len):\n",
    "            # Teacher forcing: always feed the ground-truth token for this step.\n",
    "            decoder_input = target_seq[:, i].unsqueeze(1)\n",
    "            outs, hidden = self.decoder(decoder_input, hidden, encoder_outs)\n",
    "            all_outs[:, i, :] = outs.squeeze(1)\n",
    "\n",
    "        return all_outs\n",
    "\n",
    "    def generate(self, source_seq, max_length=100, start_token=101, end_token=102):\n",
    "        \"\"\"Greedy decoding.\n",
    "\n",
    "        source_seq:  [batch_size, src_len] source token ids\n",
    "        max_length:  maximum number of tokens to generate per sequence\n",
    "        start_token: id fed to the decoder at the first step\n",
    "        end_token:   id that marks the end of a sequence\n",
    "\n",
    "        Returns a long tensor [batch_size, generated_len] with generated_len <= max_length.\n",
    "        Once a sequence has produced ``end_token`` all its later positions are\n",
    "        filled with ``end_token``. Note: switches the module to eval mode.\n",
    "        \"\"\"\n",
    "        self.eval()\n",
    "        batch_size = source_seq.size(0)\n",
    "        device = source_seq.device\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # The decoder needs the per-position encoder outputs for attention,\n",
    "            # not just the final hidden state.\n",
    "            encoder_outs, hidden = self.encoder(source_seq)\n",
    "\n",
    "            decoder_input = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)\n",
    "            # Tracks which sequences have already emitted end_token.\n",
    "            finished = torch.zeros(batch_size, 1, dtype=torch.bool, device=device)\n",
    "            generated_tokens = []\n",
    "\n",
    "            for _ in range(max_length):\n",
    "                # output: [batch_size, 1, vocab_size]\n",
    "                output, hidden = self.decoder(decoder_input, hidden, encoder_outs)\n",
    "\n",
    "                # Greedy choice; finished sequences keep emitting end_token.\n",
    "                next_token = output.argmax(dim=2).masked_fill(finished, end_token)\n",
    "                generated_tokens.append(next_token)\n",
    "                finished = finished | (next_token == end_token)\n",
    "\n",
    "                # Feed the prediction back in as the next input.\n",
    "                decoder_input = next_token\n",
    "\n",
    "                # Stop once every sequence has produced end_token at some step.\n",
    "                if bool(finished.all()):\n",
    "                    break\n",
    "\n",
    "            if not generated_tokens:\n",
    "                # max_length <= 0: nothing was generated.\n",
    "                return torch.empty((batch_size, 0), dtype=torch.long, device=device)\n",
    "\n",
    "            generated_seq = torch.cat(generated_tokens, dim=1)\n",
    "\n",
    "        return generated_seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0f89b8c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "总参数: 13,751,433\n",
      "可训练参数: 13,751,433\n",
      "占用空间: 52.46 MB\n"
     ]
    }
   ],
   "source": [
    "# Build the model with smaller sizes than the class defaults and move it to the selected device.\n",
    "model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)\n",
    "\n",
    "# Parameter counts: all parameters, and only those with requires_grad set.\n",
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f\"总参数: {total_params:,}\")\n",
    "print(f\"可训练参数: {trainable_params:,}\")\n",
    "# NOTE(review): the factor 4 presumably stands for 4 bytes per float32 parameter - confirm.\n",
    "print(f\"占用空间: {total_params * 4 / 1024 / 1024:.2f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1eeb0ea",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "745bc03d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch= 0\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "796e1c5bab264da88fb5623156d81979",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/372 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12ca4b462c114851a32e9140688f2046",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP1/6, Train Loss:5.894558247699533, Val Loss:5.517913913726806\n",
      "Saved best model with validation loss: 5.5179\n",
      "epoch= 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d7245076608a4707b6a74fdc4b8fe1b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/372 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23ddec08832d4bc8b67ebf2f452e6205",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP2/6, Train Loss:4.9056418044592744, Val Loss:4.402488088607788\n",
      "Saved best model with validation loss: 4.4025\n",
      "epoch= 2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9eef5096e8fe4001895ebf5e8b9e3608",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/372 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ef345989419478d89199f9e1043bf56",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP3/6, Train Loss:3.9667646461917507, Val Loss:3.687987804412842\n",
      "Saved best model with validation loss: 3.6880\n",
      "epoch= 3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a4445cad4e784c5fbf86bbab3a8101c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/372 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1cb3fd64b1064cab9bcbbe16f32cfe31",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP4/6, Train Loss:3.365915233729988, Val Loss:3.2367999792099\n",
      "Saved best model with validation loss: 3.2368\n",
      "epoch= 4\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e57448a8c6924d9295b0404735cd950d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/372 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d0527736fe7340e487bc52c7eb3dbfb2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP5/6, Train Loss:2.9838629576467697, Val Loss:2.9496025681495666\n",
      "Saved best model with validation loss: 2.9496\n",
      "epoch= 5\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16ac93fe2b4e4e269ef5b9ec911299a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train:   0%|          | 0/372 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a11ef289a874ccbb41faf73b2be74d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation:   0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EP6/6, Train Loss:2.7182633553140905, Val Loss:2.7491687417030333\n",
      "Saved best model with validation loss: 2.7492\n"
     ]
    }
   ],
   "source": [
    "# Half-precision (mixed-precision) training\n",
    "from collections import defaultdict\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from torch import autocast, GradScaler\n",
    "# Gradient scaler; train_epoch uses it to scale the loss and to drive the optimizer step.\n",
    "scaler = GradScaler()\n",
    "\n",
    "\n",
    "# Mean cross-entropy; target id 0 is ignored (0 is also the padding_idx of the model's embeddings).\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')\n",
    "optimizer = optim.Adam(model.parameters(), \n",
    "                            lr=learning_rate, \n",
    "                            weight_decay=weight_decay)\n",
    "# Plateau scheduler in 'min' mode (factor 0.5, patience 2); the training loop steps it with the validation loss.\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)\n",
    "\n",
    "# Per-epoch metrics; the training loop appends under 'train_loss' and 'val_loss'.\n",
    "history = defaultdict(list)\n",
    "\n",
    "def train_epoch():\n",
    "    \"\"\"Run one optimisation pass over ``train_iter`` and return the mean per-batch loss.\n",
    "\n",
    "    Uses the notebook-level ``model``, ``criterion``, ``optimizer``, ``scaler`` and ``device``.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    batches = tqdm(train_iter, desc=\"Train\")\n",
    "\n",
    "    for src, tgt in batches:\n",
    "        src = src.to(device)\n",
    "        tgt = tgt.to(device)\n",
    "\n",
    "        # The decoder is fed tgt[:, :-1]; the loss targets are tgt shifted left by one.\n",
    "        # Forward pass and loss run under bfloat16 autocast.\n",
    "        with autocast(device, dtype=torch.bfloat16):\n",
    "            logits = model(src, tgt[:, :-1])\n",
    "            # CrossEntropyLoss expects the class dimension second: [batch, vocab, steps].\n",
    "            loss = criterion(logits.permute(0, 2, 1), tgt[:, 1:])\n",
    "\n",
    "        scaler.scale(loss).backward()\n",
    "        scaler.step(optimizer)\n",
    "        scaler.update()\n",
    "        # Gradients are cleared after the step, so the next backward starts from zero.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        batch_loss = loss.item()\n",
    "        train_loss += batch_loss\n",
    "        batches.set_postfix({'Loss': f'{batch_loss:.4f}'})\n",
    "\n",
    "    return train_loss / len(train_iter)\n",
    "\n",
    "def test_epoch():\n",
    "    \"\"\"Evaluate on ``val_iter`` without gradient tracking and return the mean per-batch loss.\n",
    "\n",
    "    Uses the notebook-level ``model``, ``criterion`` and ``device``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    batch_losses = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        batches = tqdm(val_iter, desc=\"Validation\")\n",
    "\n",
    "        for src, tgt in batches:\n",
    "            src, tgt = src.to(device), tgt.to(device)\n",
    "\n",
    "            # Same input/target shift as in training.\n",
    "            logits = model(src, tgt[:, :-1])\n",
    "            batch_loss = criterion(logits.permute(0, 2, 1), tgt[:, 1:]).item()\n",
    "\n",
    "            batch_losses.append(batch_loss)\n",
    "            batches.set_postfix({'Loss': f'{batch_loss:.4f}'})\n",
    "\n",
    "    return sum(batch_losses) / len(val_iter)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12ddabac",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Train for NUM_EPOCHS epochs, checkpointing whenever the validation loss improves.\n",
    "best_val_loss = np.inf\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    print(\"epoch=\", epoch)\n",
    "\n",
    "    train_loss = train_epoch()\n",
    "    history['train_loss'].append(train_loss)\n",
    "\n",
    "    val_loss = test_epoch()\n",
    "    history['val_loss'].append(val_loss)\n",
    "\n",
    "    # The plateau scheduler reacts to this epoch's validation loss.\n",
    "    scheduler.step(val_loss)\n",
    "    current_lr = optimizer.param_groups[0]['lr']\n",
    "\n",
    "    print(f\"EP{epoch + 1}/{NUM_EPOCHS}, Train Loss:{train_loss}, Val Loss:{val_loss}\")\n",
    "\n",
    "    # No improvement: keep the previous checkpoint.\n",
    "    if not val_loss < best_val_loss:\n",
    "        continue\n",
    "\n",
    "    best_val_loss = val_loss\n",
    "    torch.save(model.state_dict(), \"best_seq2seq_attention.pth\")\n",
    "    print(f\"Saved best model with validation loss: {best_val_loss:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60fccd6c",
   "metadata": {},
   "source": [
    "## 可视化"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6450214d",
   "metadata": {},
   "source": [
    "## 翻译"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ff30d269",
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(\n",
    "    model: Seq2Seq,\n",
    "    english_text: str,\n",
    "    encoder_tokenizer,\n",
    "    decoder_tokenizer,\n",
    "    max_length: int = 50,\n",
    "):\n",
    "    \"\"\"\n",
    "    Translate an English text into Chinese with the given Seq2Seq model.\n",
    "\n",
    "    The text is tokenized with ``encoder_tokenizer`` (padded/truncated to\n",
    "    ``max_length``), decoded greedily via ``model.generate`` using the decoder\n",
    "    tokenizer's CLS/SEP ids as start/end tokens, and the first generated\n",
    "    sequence is returned as a string with all spaces removed.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    encoded = encoder_tokenizer(\n",
    "        english_text,\n",
    "        max_length=max_length,\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "\n",
    "    generated_ids = model.generate(\n",
    "        encoded[\"input_ids\"].to(device),\n",
    "        max_length=max_length,\n",
    "        start_token=decoder_tokenizer.cls_token_id,\n",
    "        end_token=decoder_tokenizer.sep_token_id,\n",
    "    )\n",
    "\n",
    "    # Only the first sequence of the batch is decoded.\n",
    "    decoded = decoder_tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n",
    "\n",
    "    # The Chinese BERT tokenizer may insert spaces between characters when decoding.\n",
    "    return decoded.replace(\" \", \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4170ee44",
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Decoder.forward() missing 1 required positional argument: 'encoder_outs'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[12], line 15\u001b[0m\n\u001b[1;32m      7\u001b[0m speed_test_cases \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m      8\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m      9\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHow are you doing today?\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mI really enjoy reading books and learning about different cultures around the world.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     11\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m     12\u001b[0m ]\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m sentence \u001b[38;5;129;01min\u001b[39;00m speed_test_cases:\n\u001b[0;32m---> 15\u001b[0m     \u001b[38;5;28mprint\u001b[39m(sentence, \u001b[43mtranslate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msentence\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoder_tokenizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdecoder_tokenizer\u001b[49m\u001b[43m)\u001b[49m)\n",
      "Cell \u001b[0;32mIn[11], line 25\u001b[0m, in \u001b[0;36mtranslate\u001b[0;34m(model, english_text, encoder_tokenizer, decoder_tokenizer, max_length)\u001b[0m\n\u001b[1;32m     22\u001b[0m start_token_id \u001b[38;5;241m=\u001b[39m decoder_tokenizer\u001b[38;5;241m.\u001b[39mcls_token_id\n\u001b[1;32m     23\u001b[0m end_token_id \u001b[38;5;241m=\u001b[39m decoder_tokenizer\u001b[38;5;241m.\u001b[39msep_token_id\n\u001b[0;32m---> 25\u001b[0m generated_ids \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     26\u001b[0m \u001b[43m    \u001b[49m\u001b[43msource_seq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     27\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     28\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstart_token\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstart_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     29\u001b[0m \u001b[43m    \u001b[49m\u001b[43mend_token\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mend_token_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     30\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     31\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m decoder_tokenizer\u001b[38;5;241m.\u001b[39mdecode(\n\u001b[1;32m     32\u001b[0m     generated_ids[\u001b[38;5;241m0\u001b[39m], skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m     33\u001b[0m )\n\u001b[1;32m     35\u001b[0m \u001b[38;5;66;03m# BERT中文tokenizer解码后可能在字与字之间添加空格\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[8], line 130\u001b[0m, in \u001b[0;36mSeq2Seq.generate\u001b[0;34m(self, source_seq, max_length, start_token, end_token)\u001b[0m\n\u001b[1;32m    126\u001b[0m generated_tokens \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m    128\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_length):\n\u001b[1;32m    129\u001b[0m     \u001b[38;5;66;03m# (bs, 1, vocab), (bs, l, hidden)\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m     output, hidden \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecoder\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdecoder_input\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    132\u001b[0m     \u001b[38;5;66;03m# Get the token with highest probability\u001b[39;00m\n\u001b[1;32m    133\u001b[0m     next_token \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "\u001b[0;31mTypeError\u001b[0m: Decoder.forward() missing 1 required positional argument: 'encoder_outs'"
     ]
    }
   ],
   "source": [
    "# Rebuild the architecture with the same hyperparameters used when the checkpoint was trained.\n",
    "model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)\n",
    "\n",
    "# map_location remaps the stored tensors onto the current device, so a checkpoint\n",
    "# saved on a GPU still loads on a CPU-only machine.\n",
    "state_dict = torch.load(\"./best_seq2seq_attention.pth\", map_location=device, weights_only=True)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "# Sentences of increasing length used as a quick translation smoke test.\n",
    "speed_test_cases = [\n",
    "    \"Hello.\",\n",
    "    \"How are you doing today?\",\n",
    "    \"I really enjoy reading books and learning about different cultures around the world.\",\n",
    "    \"The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky.\"\n",
    "]\n",
    "\n",
    "for sentence in speed_test_cases:\n",
    "    print(sentence, translate(model, sentence, encoder_tokenizer, decoder_tokenizer))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
